{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05fa329",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.models import ModelXtoCInception\n",
    "from src.preprocessing.CUB import preprocessing_CUB\n",
    "from src.utils import *\n",
    "from src.training import run_epoch_x_to_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f239f5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 34 classes.\n",
      "Found labels for 2022 images.\n",
      "Generated one-hot matrix with shape: (2022, 34)\n",
      "Total number of concept columns: 28\n",
      "Found 2013 images.\n",
      "Processing in 63 batches of size 32 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 63/63 [00:14<00:00,  4.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 2013 images.\n",
      "Labels shape: (2013, 34)\n",
      "Concepts shape: (2013, 28)\n",
      "Image tensors length: 2013\n",
      "Dataset initialized with 826 pre-sorted items.\n",
      "Dataset initialized with 790 pre-sorted items.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the CUB loaders; `concept_labels` is reused further down to weight the loss.\n",
    "concept_labels, train_loader, test_loader = preprocessing_CUB(\n",
    "    training=True,\n",
    "    class_concepts=True,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c80140bb",
   "metadata": {},
   "source": [
    "# Training Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae643c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import find_class_imbalance\n",
    "from src.config import CUB_CONFIG as config_dict\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3d7846bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class and concept counts from the CUB config (passed to the model as n_classes / n_concepts).\n",
    "N_CLASSES = config_dict['N_CLASSES']\n",
    "N_TRIMMED_CONCEPTS = config_dict['N_TRIMMED_CONCEPTS']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaf19306",
   "metadata": {},
   "source": [
    "**Find the device to run the model on (CUDA GPU, Apple MPS, or CPU).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5316e340",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "# Prefer CUDA, then the MPS backend, and fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b5b0a8",
   "metadata": {},
   "source": [
    "**Instantiate the model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82fc53a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Instantiated (X -> C)\n"
     ]
    }
   ],
   "source": [
    "# `n_classes` is still needed for the model structure, but that part won't be trained/used here.\n",
    "model = ModelXtoCInception(\n",
    "    pretrained=True,\n",
    "    freeze=True,\n",
    "    n_classes=N_CLASSES,\n",
    "    use_aux=True,\n",
    "    n_concepts=N_TRIMMED_CONCEPTS,\n",
    ")\n",
    "\n",
    "# Move the model onto the device selected above.\n",
    "model = model.to(device)\n",
    "print(\"Model Instantiated (X -> C)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4731e539",
   "metadata": {},
   "source": [
    "### Loss\n",
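    "We use a weighted loss: each concept gets its own `BCEWithLogitsLoss`, weighted by the value `find_class_imbalance` computes for that concept.\n",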
    "\n",
    "`BCEWithLogitsLoss()` performs 2 steps:\n",
    "1. $\\sigma(x)$\n",
    "    - Applies the sigmoid function to the logits to get probabilities.\n",
    "2. $\\text{BCE}(\\sigma(x), y) = -\\left[ y \\cdot \\text{log}(\\sigma(x)) + (1-y) \\cdot \\text{log}(1-\\sigma(x)) \\right]$\n",
    "    - Compute binary cross-entropy between output probabilities ($\\sigma(x)$) and ground truths ($y$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "be2ad09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_weighted_loss = True  # flip to False for a plain, unweighted loss\n",
    "\n",
    "# Build one BCE-with-logits criterion per concept.\n",
    "if not use_weighted_loss:\n",
    "    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_TRIMMED_CONCEPTS)]\n",
    "else:\n",
    "    # presumably yields one imbalance ratio per concept -- verify against find_class_imbalance\n",
    "    concept_weights = find_class_imbalance(concept_labels)\n",
    "    attr_criterion = [\n",
    "        nn.BCEWithLogitsLoss(\n",
    "            weight=torch.tensor([imbalance], device=device, dtype=torch.float)\n",
    "        )\n",
    "        for imbalance in concept_weights\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "936c945c",
   "metadata": {},
   "source": [
    "### Optimiser\n",
    "Use the same settings as the CBM repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b50f189a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizer and Scheduler Ready\n"
     ]
    }
   ],
   "source": [
    "lr = 0.01\n",
    "weight_decay = 0.00004  # plays the role of lambda in L2 regularisation\n",
    "\n",
    "# Only parameters that still require gradients are handed to SGD.\n",
    "optimizer = optim.SGD(\n",
    "    (param for param in model.parameters() if param.requires_grad),\n",
    "    lr=lr,\n",
    "    momentum=0.9,\n",
    "    weight_decay=weight_decay,\n",
    ")\n",
    "\n",
    "# StepLR multiplies the LR by `gamma` once every `scheduler_step` calls to scheduler.step().\n",
    "scheduler_step = 1000\n",
    "scheduler = optim.lr_scheduler.StepLR(\n",
    "    optimizer,\n",
    "    step_size=scheduler_step,\n",
    "    gamma=0.1,\n",
    ")\n",
    "\n",
    "print(\"Optimizer and Scheduler Ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e648c53",
   "metadata": {},
   "source": [
    "### Training and Validation Loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5a03d3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes over the training data.\n",
    "epochs = 50\n",
    "# NOTE(review): `log_interval` is not referenced by any cell in this notebook -- confirm it is still needed.\n",
    "log_interval = 50\n",
    "\n",
    "# Best validation accuracy seen so far; the training loop checkpoints when it is beaten.\n",
    "best_val_acc = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "743f1447",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Epoch 1/50 ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   0%|          | 0/13 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                     \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Train Summary | Loss: 153.6553 | Acc: 66.102\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Val Summary   | Loss: 265.9652 | Acc: 75.425\n",
      "Validation accuracy improved (0.000 -> 75.425). Saving model...\n",
      "Model saved to x_to_c_best_model.pth\n",
      "Current LR: 0.01\n",
      "--- Epoch 2/50 ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                     \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 Train Summary | Loss: 139.8827 | Acc: 76.340\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                       \r"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m      9\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m test_loader:\n\u001b[32m     10\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m         val_loss, val_acc = \u001b[43mrun_epoch_x_to_c\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattr_criterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_concepts\u001b[49m\u001b[43m=\u001b[49m\u001b[43mN_TRIMMED_CONCEPTS\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m     13\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch+\u001b[32m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m Val Summary   | Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mval_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mval_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.3f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m     15\u001b[39m     \u001b[38;5;66;03m# Save best model based on validation accuracy\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/career/UJM Lab/hybrid-cbm-prototype-model/src/training/phase_1_training.py:50\u001b[39m, in \u001b[36mrun_epoch_x_to_c\u001b[39m\u001b[34m(model, loader, criterion_list, optimizer, n_concepts, is_training, use_aux, device, verbose, return_outputs)\u001b[39m\n\u001b[32m     47\u001b[39m desc = \u001b[33m\"\u001b[39m\u001b[33mTraining\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m is_training \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mValidation\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m     48\u001b[39m tqdm_loader = tqdm(loader, desc=desc, leave=\u001b[38;5;28;01mFalse\u001b[39;00m, disable=\u001b[38;5;129;01mnot\u001b[39;00m verbose) \u001b[38;5;66;03m# Wrap the loader with tqdm\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtqdm_loader\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m     51\u001b[39m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     52\u001b[39m \u001b[43m    \u001b[49m\u001b[43mconcept_labels\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/tqdm/std.py:1181\u001b[39m, in \u001b[36mtqdm.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m   1178\u001b[39m time = \u001b[38;5;28mself\u001b[39m._time\n\u001b[32m   1180\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1181\u001b[39m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m   1182\u001b[39m \u001b[43m        \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\n\u001b[32m   1183\u001b[39m \u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[32m   1184\u001b[39m \u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/utils/data/dataloader.py:708\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m    705\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    706\u001b[39m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m    707\u001b[39m     \u001b[38;5;28mself\u001b[39m._reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m708\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    709\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m    710\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m    711\u001b[39m     \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m    712\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m    713\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m    714\u001b[39m ):\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1446\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m   1443\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   1444\u001b[39m     \u001b[38;5;66;03m# no valid `self._rcvd_idx` is found (i.e., didn't break)\u001b[39;00m\n\u001b[32m   1445\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._persistent_workers:\n\u001b[32m-> \u001b[39m\u001b[32m1446\u001b[39m         \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_shutdown_workers\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1447\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n\u001b[32m   1449\u001b[39m \u001b[38;5;66;03m# Now `self._rcvd_idx` is the batch index we want to fetch\u001b[39;00m\n\u001b[32m   1450\u001b[39m \n\u001b[32m   1451\u001b[39m \u001b[38;5;66;03m# Check if the next sample has already been generated\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1582\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._shutdown_workers\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m   1577\u001b[39m         \u001b[38;5;28mself\u001b[39m._mark_worker_as_unavailable(worker_id, shutdown=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m   1578\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._workers:\n\u001b[32m   1579\u001b[39m     \u001b[38;5;66;03m# We should be able to join here, but in case anything went\u001b[39;00m\n\u001b[32m   1580\u001b[39m     \u001b[38;5;66;03m# wrong, we set a timeout and if the workers fail to join,\u001b[39;00m\n\u001b[32m   1581\u001b[39m     \u001b[38;5;66;03m# they are killed in the `finally` block.\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1582\u001b[39m     \u001b[43mw\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43m_utils\u001b[49m\u001b[43m.\u001b[49m\u001b[43mMP_STATUS_CHECK_INTERVAL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1583\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m q \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._index_queues:\n\u001b[32m   1584\u001b[39m     q.cancel_join_thread()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/process.py:149\u001b[39m, in \u001b[36mBaseProcess.join\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m    147\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m._parent_pid == os.getpid(), \u001b[33m'\u001b[39m\u001b[33mcan only join a child process\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m    148\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m._popen \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[33m'\u001b[39m\u001b[33mcan only join a started process\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m res = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_popen\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    150\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m res \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    151\u001b[39m     _children.discard(\u001b[38;5;28mself\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/popen_fork.py:40\u001b[39m, in \u001b[36mPopen.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m     38\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m     39\u001b[39m     \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmultiprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconnection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m wait\n\u001b[32m---> \u001b[39m\u001b[32m40\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msentinel\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m     41\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m     42\u001b[39m \u001b[38;5;66;03m# This shouldn't block if wait() returned successfully.\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/connection.py:948\u001b[39m, in \u001b[36mwait\u001b[39m\u001b[34m(object_list, timeout)\u001b[39m\n\u001b[32m    945\u001b[39m     deadline = time.monotonic() + timeout\n\u001b[32m    947\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m948\u001b[39m     ready = \u001b[43mselector\u001b[49m\u001b[43m.\u001b[49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    949\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m ready:\n\u001b[32m    950\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m [key.fileobj \u001b[38;5;28;01mfor\u001b[39;00m (key, events) \u001b[38;5;129;01min\u001b[39;00m ready]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.9/lib/python3.11/selectors.py:415\u001b[39m, in \u001b[36m_PollLikeSelector.select\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m    413\u001b[39m ready = []\n\u001b[32m    414\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m415\u001b[39m     fd_event_list = \u001b[38;5;28mself\u001b[39m._selector.poll(timeout)\n\u001b[32m    416\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n\u001b[32m    417\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m ready\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):\n",
    "    print(f\"--- Epoch {epoch+1}/{epochs} ---\")\n",
    "\n",
    "    # Training pass over the train loader.\n",
    "    train_loss, train_acc = run_epoch_x_to_c(\n",
    "        model,\n",
    "        train_loader,\n",
    "        attr_criterion,\n",
    "        optimizer,\n",
    "        is_training=True,\n",
    "        use_aux=True,\n",
    "        n_concepts=N_TRIMMED_CONCEPTS,\n",
    "        device=device,\n",
    "        verbose=True,\n",
    "    )\n",
    "    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')\n",
    "\n",
    "    # Validation pass; skipped entirely when `test_loader` is falsy.\n",
    "    if test_loader:\n",
    "        # Gradients are not tracked while evaluating.\n",
    "        with torch.no_grad():\n",
    "            val_loss, val_acc = run_epoch_x_to_c(\n",
    "                model,\n",
    "                test_loader,\n",
    "                attr_criterion,\n",
    "                optimizer,\n",
    "                n_concepts=N_TRIMMED_CONCEPTS,\n",
    "                device=device,\n",
    "                verbose=True,\n",
    "            )\n",
    "\n",
    "        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')\n",
    "\n",
    "        # Checkpoint whenever validation accuracy beats the best value so far.\n",
    "        # Note: this serializes the whole module object, not just its state_dict.\n",
    "        if val_acc > best_val_acc:\n",
    "            print(f\"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...\")\n",
    "            best_val_acc = val_acc\n",
    "            torch.save(model, 'x_to_c_best_model.pth')\n",
    "            print(\"Model saved to x_to_c_best_model.pth\")\n",
    "\n",
    "    # Advance the LR schedule once per epoch and report the current LR.\n",
    "    scheduler.step()\n",
    "    print(f\"Current LR: {optimizer.param_groups[0]['lr']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b4baf8-ac50-4cc7-ac98-3af1a367a0da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28684261-09da-40ca-bf12-a13f99b29d04",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
